#############################################################################
# Grouped fixed effects (GFE) estimation by period
# - Group-specific slopes and unrestricted county intercepts, both constant
#   within a period
# - Requires the original data and the number of periods
#
#############################################################################

#############################################################################
# Wrapper function to call for estimation
#############################################################################

# Estimate the period-specific grouped fixed-effects (GFE) model.
#
# Periods are np consecutive blocks of years covering y0..y1. The first
# (n_years %% np) blocks get one extra year. Model variables are demeaned
# within county x period. Starting values come from start_FE(), and
# tgroup_FE() runs the iterative estimator from Ns starting points.
#
# Args:
#   data            data.table with fips, year, yield, ddr, precr, corn_area,
#                   the columns in state_year_list, and the `state` column
#                   used for clustering in start_FE().
#   state_year_list names of the state-year dummy columns.
#   np, G           number of periods and number of groups.
#   niter, tol      iteration cap and tolerance on the change in objective.
#   Ns              number of starting values (1 given + Ns - 1 random).
#   sseed           RNG seed for the random starts.
#   save_name       tag used in the checkpoint file name.
#   se_gamma_multi, se_beta_multi  scale factors for the random-start spread.
#   y0, y1          first and last sample year (defaults 1950 and 2005).
#
# Returns: list(res = tgroup_FE() output,
#               dc = county-year data with group labels ordered by
#                    increasing last-period beta (column `group_new`),
#               time_dif = elapsed time,
#               per = np x 2 matrix of period start/end years).
tgfe_wrap <- function(data, state_year_list, np, G, niter, tol, Ns, sseed = 1,
                      save_name, se_gamma_multi = 1, se_beta_multi = 1,
                      y0 = 1950, y1 = 2005) {

  t0 <- Sys.time()
  dt <- copy(data)

  # Periods --------------------------------------------------------------
  n_years <- y1 - y0 + 1
  stopifnot(np >= 1, np <= n_years)
  per_len <- rep(n_years %/% np, np)
  n_extra <- n_years - sum(per_len)
  per_len[seq_len(n_extra)] <- per_len[seq_len(n_extra)] + 1
  per_end <- y0 - 1 + cumsum(per_len)
  # Built directly with matrix() so `per` stays np x 2 even when np == 1
  # (dropping a row from a 2-row matrix would collapse it to a vector)
  per <- matrix(c(per_end - per_len + 1, per_end), ncol = 2)

  # Tag each observation with its period; years outside y0..y1 are not set
  for (j in seq(1, np, by = 1)) {
    dt[year >= per[j, 1] & year <= per[j, 2], period := j]
  }

  # Starting values ------------------------------------------------------
  start0 <- start_FE(data = dt, state_year_list = state_year_list,
                     per = per, np = np, G = G)

  # Demean within county x period -----------------------------------------
  # Done with := on a copy so rows keep the order of `dt`. Aggregating with
  # `by` regroups rows, so re-attaching `year` by position would mismatch
  # years unless `dt` happened to be sorted by county and year. Columns are
  # made double first so the grouped assignment cannot truncate into
  # integer (e.g. dummy) columns.
  cols_old <- c("yield", "ddr", "precr", state_year_list)
  dt_dm <- copy(dt[, c("fips", "period", cols_old, "year"), with = FALSE])
  dt_dm[, (cols_old) := lapply(.SD, as.numeric), .SDcols = cols_old]
  dt_dm[, (cols_old) := lapply(.SD, function(v) v - mean(v)),
        .SDcols = cols_old, by = c("fips", "period")]

  # Iterate --------------------------------------------------------------
  save_pos <- paste0("Data_new/GFE_period/Gfe_res1_", save_name, ".RData")
  res <- tgroup_FE(data = dt_dm, per = per, np = np, G = G,
                   beta0 = start0$beta0, gamma0 = start0$gamma0,
                   se_beta = start0$se_beta * se_beta_multi,
                   se_gamma = start0$se_gamma * se_gamma_multi,
                   Ns = Ns, niter = niter, tol = tol, sseed = sseed,
                   save_name = save_pos, save_every = 200)

  # Diagnostic: the state-year effects ("sy_<state>_<year>") for each
  # period's end year should be zero. Matched by name pattern rather than
  # through a global list of states.
  best_gamma <- res$gamma[1, ]
  for (j in seq_len(np)) {
    end_year_fe <- grepl(paste0("^sy_.*_", per[j, 2], "$"), names(best_gamma))
    print(summary(best_gamma[end_year_fe]))
  }

  # Group / marginal-effect data -----------------------------------------
  # Order groups once, on the estimator's raw labels, by increasing beta in
  # the last period. Re-sorting the already relabelled output of
  # group_fips() would apply the permutation to the wrong labels.
  dc <- group_fips(data = dt, per = per, np = np, G = G, res = res, K = np)
  # Same output column as before: the ordered label lives in `group_new`
  dc[, group_new := as.integer(group)]
  dc[, group := NULL]

  t1 <- Sys.time()
  list(res = res, dc = dc, time_dif = t1 - t0, per = per)
}

#############################################################################
# Function to generate overall starting values from FE
#############################################################################

# Starting values for the grouped FE iteration, built from ordinary FE fits.
#
# A pooled regression supplies three things: gamma0 (precr and state-year
# coefficients), their standard errors se_gamma, and se_beta (the sd of the
# county-specific ddr slopes). The pooled model has county fixed effects,
# county-specific ddr slopes, a common precr effect and state-year dummies,
# with SEs clustered by `state`.
# The same model is then fit separately in each period. That period's county
# slopes are cut at their quantiles into G bins, and each bin's mean becomes
# the starting slope "ddr_g<g>_p<j>".
#
# Returns: list(beta0 = named vector of length G * np, gamma0, se_beta,
#               se_gamma).
start_FE <- function(data, state_year_list, per, np, G) {

  pooled_data <- copy(data)
  form <- as.formula(paste("yield ~  precr + ddr:factor(fips) +",
                           paste(state_year_list, collapse = "+"),
                           " | factor(fips)"))

  # Pooled fit: common coefficients and the scale of slope heterogeneity
  pooled_fit <- feols(fml = form, data = pooled_data, cluster = "state")
  gamma0 <- pooled_fit$coefficients[c("precr", state_year_list)]
  se_beta <- sd(slopes_dt(f = pooled_fit, names_var = "")$coef)
  se_gamma <- pooled_fit$coeftable[names(gamma0), "Std. Error"]

  # Per-period fit -> G binned starting slopes for period j
  slopes_for_period <- function(j) {
    period_data <- pooled_data[year >= per[j, 1] & year <= per[j, 2]]
    period_fit <- feols(fml = form, data = period_data, cluster = "state")
    slopes <- slopes_dt(f = period_fit, names_var = "")

    # Bin index = largest g whose g-th quantile cut point is <= the slope
    cuts <- quantile(slopes$coef, probs = c(seq(0, 1, length.out = G + 1)))
    slopes[, q := findInterval(coef, cuts[seq_len(G)])]

    bin_means <- vapply(seq(1, G, 1),
                        function(g) mean(slopes[q == g, coef]),
                        numeric(1))
    setNames(bin_means, paste0(paste0("ddr_g", seq(1, G, 1)), paste0("_p", j)))
  }

  beta0_all <- unlist(lapply(seq(1, np, 1), slopes_for_period))

  return(list(beta0 = beta0_all, gamma0 = gamma0,
              se_beta = se_beta, se_gamma = se_gamma))
}



#############################################################################
# Grouped FE for a given starting value 
#############################################################################

# Run the grouped fixed-effects iteration from one starting value.
#
# The two steps below alternate until the objective changes by no more than
# `tol`, or until the iteration counter exceeds `niter`:
#   1. Assignment: each county (fips) joins the group whose period-specific
#      slopes give the smallest sum of squared residuals over its rows,
#      holding the common coefficients gamma fixed.
#   2. Update: all group x period slopes and gamma are re-estimated by OLS
#      (no intercept) given that assignment.
#
# Args:
#   data    data.table with fips, year, period, yield, ddr and the columns
#           named in names(gamma0) (tgfe_wrap passes data demeaned by
#           county x period).
#   beta0   named starting slopes, names "ddr_g<g>_p<j>".
#   gamma0  named starting coefficients of the common regressors.
#   `per` is accepted for interface compatibility but not used.
#
# Returns: list(iter = iterations run, obj = SSR after each iteration,
#               beta, gamma = final coefficients,
#               group = data.table(fips, group)).
tgroup_FE_s <- function(data, per, np, G, beta0, gamma0, niter = 1e4, tol = 1e-7) {

  dtm <- copy(data)
  obj_vec <- c()

  y <- dtm$yield
  Z <- as.matrix(dtm[, .SD, .SDcols = names(gamma0)])
  x <- dtm$ddr
  res_cols <- paste0("res", seq_len(G))

  # The update-step regression is identical in every iteration
  form <- as.formula(paste("yield ~ ", paste(names(beta0), collapse = "+"), "+",
                           paste(names(gamma0), collapse = "+"), "-1"))

  k <- 0
  obj <- sum(y^2)
  dch <- 1e8
  gamma_now <- gamma0
  beta_now <- beta0

  while (k <= niter && dch > tol) {

    ### Step 1: assignment
    # Common part of the fit, shared by all candidate groups
    z_fit <- as.vector(Z %*% gamma_now)
    resg <- data.table(fips = dtm$fips)
    for (g in seq_len(G)) {
      # Slope of group g in each row's own period (NA if no such coefficient)
      b_g <- unname(beta_now[paste0("ddr_g", g, "_p", dtm$period)])
      resg[, (res_cols[g]) := (y - b_g * x - z_fit)^2]
    }
    agg <- resg[, lapply(.SD, sum), by = "fips", .SDcols = res_cols]

    # Pick the group with the smallest SSR; ties go to the lowest g. An NA
    # SSR (e.g. the group had no observations in some period, so its slope
    # is NA) counts as +Inf: that group is never chosen, and it cannot
    # block the comparison for the other groups.
    res_mat <- as.matrix(agg[, res_cols, with = FALSE])
    res_mat[is.na(res_mat)] <- Inf
    agg[, group := as.numeric(max.col(-res_mat, ties.method = "first"))]

    dt_now <- merge(dtm, agg[, .(fips, group)], by = "fips", all.x = TRUE)

    ### Step 2: update
    # Regressors: ddr interacted with group g and period j dummies
    for (g in seq_len(G)) {
      for (j in seq_len(np)) {
        dt_now[, (paste0("ddr_g", g, "_p", j)) := (group == g) * ddr * (period == j)]
      }
    }
    reg_now <- feols(fml = form, data = dt_now)

    # Collinear (dropped) common coefficients are set to 0; slopes stay NA
    gamma_now <- reg_now$coefficients[names(gamma0)]
    gamma_now[is.na(gamma_now)] <- 0
    gamma_now <- setNames(gamma_now, names(gamma0))
    beta_now <- setNames(reg_now$coefficients[names(beta0)], names(beta0))

    obj_now <- reg_now$ssr
    obj_vec <- c(obj_vec, obj_now)
    dch <- abs(obj - obj_now)
    obj <- obj_now
    k <- k + 1
  }

  list(iter = k, obj = obj_vec, beta = beta_now, gamma = gamma_now,
       group = agg[, .(fips, group)])
}

#############################################################################
# Grouped FE across randomly generated starting values
#############################################################################

# Run tgroup_FE_s() from the supplied starting values and from Ns - 1
# random perturbations of them, then return every run sorted from best
# (lowest SSR) to worst.
#
# Random starts add N(0, se_beta) noise to each slope, plus multivariate
# normal noise with variances se_gamma^2 (no covariance) to gamma0.
#
# Side effects: creates/overwrites the global variables R_obj, R_iter,
# R_gamma, R_beta and R_groups, accumulated with <<-. This is intentional:
# every `save_every` random starts, save.image(save_name) checkpoints the
# global workspace, so partial results survive an interrupted run.
#
# Returns: list(conv = 1 if the best run's iteration count is below niter,
#                      else 0,
#               obj, iter, gamma and beta matrices (one row per run),
#               groups = data.table with sorted fips and V1..VNs, one
#                        column per run),
#          all sorted best first.
tgroup_FE <- function(data, per, np, G, beta0, gamma0, se_beta, se_gamma, Ns,
                      niter, tol, sseed = 1, save_name, save_every = 500) {

  dtm <- copy(data)
  set.seed(sseed)

  # Row order of every group column in R_groups
  fips_sorted <- sort(unique(dtm$fips))

  R_obj <<- c()
  R_iter <<- c()
  R_gamma <<- matrix(NA, nrow = 1, ncol = length(gamma0))
  R_beta <<- matrix(NA, nrow = 1, ncol = length(beta0))
  R_groups <<- matrix(fips_sorted, ncol = 1)

  # Append one run to the global accumulators. The run's group vector is
  # realigned to fips_sorted by fips id: tgroup_FE_s() returns counties in
  # first-appearance order, which need not be sorted.
  record_run <- function(fit) {
    R_obj <<- c(R_obj, fit$obj[length(fit$obj)])
    R_iter <<- c(R_iter, fit$iter)
    R_gamma <<- rbind(R_gamma, fit$gamma)
    R_beta <<- rbind(R_beta, fit$beta)
    R_groups <<- cbind(R_groups, fit$group$group[match(fips_sorted, fit$group$fips)])
  }

  # The starting value given
  record_run(tgroup_FE_s(data = dtm, per = per, np = np, G = G, beta0 = beta0,
                         gamma0 = gamma0, niter = niter, tol = tol))

  # Random starts (none when Ns == 1)
  for (k in seq_len(Ns - 1)) {

    print(paste0("Starting value ", k))

    beta_now <- beta0 + rnorm(n = G * np, mean = 0, sd = se_beta)
    # nrow= keeps sigma p x p even when gamma0 has a single element
    # (diag(scalar) would build an identity matrix of that size instead)
    gamma_now <- as.numeric(gamma0 + rmvnorm(n = 1, mean = rep(0, length(gamma0)),
                                             sigma = diag(se_gamma^2, nrow = length(gamma0))))
    names(gamma_now) <- names(gamma0)
    names(beta_now) <- names(beta0)

    record_run(tgroup_FE_s(data = dtm, per = per, np = np, G = G, beta0 = beta_now,
                           gamma0 = gamma_now, niter = niter, tol = tol))

    if (k %% save_every == 0) {
      save.image(save_name)
    }
  }

  # Drop the NA placeholder rows (drop = FALSE keeps matrices when Ns == 1)
  R_gamma <<- R_gamma[-1, , drop = FALSE]
  R_beta <<- R_beta[-1, , drop = FALSE]

  conv <- as.numeric(R_iter[which.min(R_obj)] < niter)

  # Best to worst; stable, NA objectives first
  pos <- order(R_obj, na.last = FALSE)
  groups_sorted <- data.table(cbind(R_groups[, 1],
                                    R_groups[, -1, drop = FALSE][, pos, drop = FALSE]))
  setnames(groups_sorted, c("fips", paste0("V", seq_len(Ns))))

  list(conv = conv, obj = R_obj[pos],
       gamma = R_gamma[pos, , drop = FALSE], beta = R_beta[pos, , drop = FALSE],
       groups = groups_sorted, iter = R_iter[pos])
}

#############################################################################
# Function that returns a county-year data set with the assigned group and marginal effect (mg)
#############################################################################

# County-year data for the best run (column V1 of res$groups).
#
# Each row gets its group, the group x period ddr slope `beta`, and the
# marginal effect `mg` (equal to beta). Group labels are then renumbered
# 1..G in increasing order of beta in period K.
#
# Args: data must contain fips, year, corn_area, period and ddr; res is the
#   tgroup_FE() output; K is the period used for ordering (default 1).
# Returns: data.table with fips, year, corn_area, period, ddr, beta, mg,
#   group.
group_fips <- function(data, per, np, G, res, K = 1) {

  assign_best <- res$groups[, .SD, .SDcols = c("fips", "V1")]
  setnames(assign_best, c("fips", "group"))
  out <- merge(data[, .SD, .SDcols = c("fips", "year", "corn_area", "period", "ddr")],
               assign_best, by = "fips", all.x = TRUE)

  # Fill in the slope of each (group, period) cell
  for (j in seq(1, np, 1)) {
    for (g in seq(1, G, 1)) {
      coef_name <- paste0("ddr_g", g, "_p", j)
      out[period == j & group == g, beta := res$beta[1, coef_name]]
    }
  }
  out[, mg := beta]

  # old_by_rank[r] = original label of the group with the r-th smallest
  # beta in period K (stable ordering, NA coefficients first)
  beta_K <- res$beta[1, ][((K - 1) * G + 1):(K * G)]
  old_by_rank <- order(beta_K, na.last = FALSE)

  # Replace `group` with its new label, appended as the last column
  new_label <- as.numeric(match(out$group, old_by_rank))
  out[, group := NULL]
  out[, group := new_label]

  out
}

#############################################################################
# Function that plots densities of marginal effects 
# 1. Per group across periods 
# 2. Per period across groups
# Groups are ordered
#############################################################################

# Plot the corn-area-weighted distribution of county-year marginal effects
# (mg) twice: coloured by group, then coloured by period. Group labels are
# ordered by increasing beta in the last period (group_fips(..., K = np)).
#
# Side effects: prints both plots and saves them as PDFs. The output
# directory is read from `save_res`, which this function does not define
# (presumably a global set by the calling script).
# Returns: the plotted county-year data, with helper columns `ones`
#   (1 for a county's first row in each period) and `count` (number of
#   periods in which the county appears). `dc_fe` is currently unused.
plot_density_2m <- function(data, dc_fe, np, per, G, res, save_name,
                            xxlim = c(-45, 20), yylim = c(0, 0.30)) {

  dtn <- copy(data)
  for (j in seq(1, np, 1)) {
    dtn[year >= per[j, 1] & year <= per[j, 2], period := j]
  }

  # County-period marginal effects, groups ordered by last-period beta
  dtp <- group_fips(data = dtn, per = per, np = np, G = G, res = res, K = np)

  # Number of periods in which each county appears
  dtp[, ones := 1]
  dtp[, count := cumsum(ones), by = c("fips", "period")]
  dtp[, ones := as.numeric(count == 1)]
  dtp[, count := sum(ones), by = "fips"]

  tot_area <- sum(dtp$corn_area)
  plot_theme <- theme(title = element_text(size = 14),
                      plot.title = element_text(hjust = 0.5),
                      plot.subtitle = element_text(hjust = 0.5),
                      axis.text = element_text(size = 18),
                      axis.title = element_text(size = 18),
                      legend.position = "right",
                      legend.text = element_text(size = 14))

  # Plot 1: colours per group. ggplot2's weighting aesthetic is `weight`
  # (singular); bar heights are shares of total corn area.
  g <- ggplot(data = dtp, aes(x = mg, weight = corn_area / tot_area,
                              color = factor(group), fill = factor(group))) +
    theme_bw() +
    geom_bar(alpha = 0.4, position = "dodge") +
    plot_theme +
    scale_color_discrete(name = "Group") + scale_fill_discrete(name = "Group") +
    xlab("Temperature mg effect") + ylab("Density") + xlim(xxlim) + ylim(yylim)
  print(g)
  ggsave(paste0(save_res, "/Histogram_mg_bygroup_", save_name, ".pdf"),
         plot = g, width = 6.3, height = 3.9, units = "in")

  # Plot 2: colours per period
  g <- ggplot(data = dtp, aes(x = mg, weight = corn_area / tot_area,
                              color = factor(period), fill = factor(period))) +
    theme_bw() +
    geom_bar(alpha = 0.4, position = "dodge") +
    plot_theme +
    scale_color_discrete(name = "Period") + scale_fill_discrete(name = "Period") +
    xlab("Temperature mg effect") + ylab("Density") + xlim(xxlim) + ylim(yylim)
  print(g)
  ggsave(paste0(save_res, "/Histogram_mg_byperiod_", save_name, ".pdf"),
         plot = g, width = 6.3, height = 3.9, units = "in")

  dtp
}


#############################################################################
# Function that plots densities of temperature across groups
# Groups are ordered
#############################################################################

# Plot corn-area-weighted densities of temperature (ddr): one curve per
# group, then one curve per period. `data` must contain ddr, corn_area,
# group and period. `np` and `G` are accepted but not used.
#
# Side effects: prints both plots and saves them as PDFs. The output
# directory is read from `save_res`, which this function does not define
# (presumably a global set by the calling script).
plot_density_ddr <- function(data, np, G, save_name, xxlim = c(-45, 20), yylim = c(0, 0.30)) {

  dtp <- copy(data)
  plot_theme <- theme(title = element_text(size = 14),
                      plot.title = element_text(hjust = 0.5),
                      plot.subtitle = element_text(hjust = 0.5),
                      axis.text = element_text(size = 18),
                      axis.title = element_text(size = 18),
                      legend.position = "right",
                      legend.text = element_text(size = 14))

  # Plot 1: density of temperature per group (pooled over periods).
  # `weight` (singular) is ggplot2's weighting aesthetic. The former empty
  # `breaks=` argument is dropped so the scale uses its default breaks.
  g <- ggplot(data = dtp, aes(x = ddr, weight = corn_area, color = factor(group))) +
    geom_density() + theme_bw() +
    plot_theme +
    scale_color_discrete(name = "Groups") +
    xlab("Temperature") + ylab("Density") + xlim(xxlim) + ylim(yylim)
  print(g)
  ggsave(paste0(save_res, "/Density_ddr_bygroups_", save_name, ".pdf"),
         plot = g, width = 6.3, height = 3.9, units = "in")

  # Plot 2: density of temperature per period (pooled over groups)
  g <- ggplot(data = dtp, aes(x = ddr, weight = corn_area, color = factor(period))) +
    geom_density() + theme_bw() +
    plot_theme +
    scale_color_discrete(name = "Period") +
    xlab("Temperature") + ylab("Density") + xlim(xxlim) + ylim(yylim)
  print(g)
  ggsave(paste0(save_res, "/Density_ddr_byperiod_", save_name, ".pdf"),
         plot = g, width = 6.3, height = 3.9, units = "in")
}

#############################################################################
# Maps of marginal effects or groups
#############################################################################

# County map of a value (e.g. a marginal effect), averaged per county with
# corn-area weights.
#
# Args:
#   data      4 columns, renamed to fips, year, corn_area, GFE.
#   trunc     if non-zero, values are clamped to [bmin, bmax]. If 0, the
#             colour limits are symmetric around zero at the largest
#             absolute county value (NA counties ignored).
#   save_name if longer than one character, the plot is saved to
#             paste0(save_res, save_name, ".pdf").
#   llabel    legend title.
#   dir       1 for the Spectral palette, -1 for it reversed.
# Reads `dt_map` (polygon data with fips, long, lat, group) and `save_res`,
# neither of which is defined here (presumably globals in the calling
# script).
plot_map <- function(data, bmin = 0, bmax = 0, trunc = 0, save_name = "",
                     llabel = "Marginal effect of temperature", dir = 1) {

  # Any other value previously left the palette undefined
  if (!isTRUE(dir %in% c(1, -1))) {
    stop("`dir` must be 1 or -1", call. = FALSE)
  }

  # Corn-area-weighted mean per county
  county_mean <- copy(data)
  setnames(county_mean, c("fips", "year", "corn_area", "GFE"))
  county_mean <- county_mean[, lapply(.SD, wtd.mean, weights = corn_area),
                             by = c("fips"), .SDcols = c("GFE")]

  spectral <- brewer.pal(11, "Spectral")
  if (dir == -1) {
    spectral <- rev(spectral)
  }
  myPalette <- colorRampPalette(spectral)

  dt_plot <- merge(dt_map, county_mean, by = "fips", all.x = TRUE)

  # Check: number of counties with a value that are missing from the map
  print(length(unique(county_mean$fips)) - length(unique(dt_plot[!is.na(GFE), fips])))

  if (trunc != 0) {
    dt_plot[GFE < bmin, GFE := bmin]
    dt_plot[GFE > bmax, GFE := bmax]
  }
  if (trunc == 0) {
    bmax <- max(abs(county_mean$GFE), na.rm = TRUE)
    bmin <- -bmax
  }

  g <- ggplot(data = dt_plot[long > -100], aes(x = long, y = lat, group = group, fill = GFE)) +
    geom_polygon(color = "gray90", size = 0.1) +
    coord_map(projection = "albers", lat0 = 45, lat1 = 55) +
    labs(fill = llabel) +
    theme(legend.position = "bottom",
          axis.line = element_blank(),
          axis.text = element_blank(),
          axis.ticks = element_blank(),
          axis.title = element_blank(),
          panel.background = element_blank(),
          panel.border = element_blank(),
          panel.grid = element_blank(),
          legend.key.width = unit(0.8, "cm")) +
    scale_fill_gradientn(colours = myPalette(100), limits = c(bmin, bmax))
  print(g)
  if (nchar(save_name) > 1) {
    ggsave(paste0(save_res, save_name, ".pdf"), plot = g)
  }
}


#############################################################################
# Save latex table of coefficients
#############################################################################

# Write a LaTeX table of the best run's slopes.
#
# The table has one row per group and columns: proportion of counties in
# the group, then the slope in each period. Rows are ordered by increasing
# last-period slope. The file is written to
# <save_res>/tGFE_coef_g<G>_np<np>.tex, where `save_res` is not defined in
# this function (presumably a global). `per` is accepted but not used.
print_coef <- function(res, np, per, G) {

  # Beta coefficients: row = group, column = period
  C <- matrix(res$beta[1, ], nrow = G, ncol = np)

  # Share of counties per group. The factor with levels 1..G keeps empty
  # groups (count 0) and keeps numeric order even when labels are
  # character, so each share lines up with its beta row.
  group_count <- table(factor(res$groups[, V1], levels = seq_len(G)))
  C <- cbind(as.vector(group_count) / nrow(res$groups), C)

  # Order groups by the last period's slope (stable, NA first).
  # drop = FALSE keeps C a matrix when G == 1.
  pos <- order(C[, np + 1], na.last = FALSE)
  C <- C[pos, , drop = FALSE]

  colnames(C) <- c("Proportion", paste0("Period ", 1:np))
  rownames(C) <- paste0("$\\beta$ group ", 1:G)

  dig <- rep(3, ncol(C) + 1)
  print(xtable(C, type = "latex", digits = dig, display = c("s", rep("f", ncol(C)))),
        hline.after = NULL,
        file = paste0(save_res, "/tGFE_coef_g", G, "_np", np, ".tex"),
        include.rownames = TRUE, include.colnames = TRUE,
        sanitize.text.function = function(x) {x}, only.contents = TRUE)
}


#############################################################################
# Function that reorders the group variable in dc data to be increasing
# in beta from the last period
#############################################################################

# Relabel groups 1..G in increasing order of the best run's beta in the
# last period (np).
#
# dc$group must hold the estimator's own labels (the ones that index
# res$beta), not labels already renumbered by group_fips().
# Returns a copy of dc with `group` removed and the new label stored in
# `group_new`.
data_groups_order <- function(dc, res, np, G) {

  # Column np of the G x np slope matrix = last-period slope of each group
  beta_last <- matrix(res$beta[1, ], nrow = G, ncol = np)[, np]

  # old_by_rank[r] = original label ranked r-th (stable, NA first)
  old_by_rank <- order(beta_last, na.last = FALSE)

  relabelled <- copy(dc)
  relabelled[, group_new := match(group, old_by_rank)]
  relabelled[, group := NULL]
  relabelled
}


#############################################################################
# Function that returns a data set of mg effects per county for the 
# 10 best starting values
#############################################################################

# County-year marginal effects of temperature for the best `nbest` runs
# (res is sorted best first), in columns mg_b1, mg_b2, ...
#
# Args:
#   G     number of groups.
#   data  has fips, year, period, corn_area and ddr.
#   res   tgroup_FE() output.
#   np    number of periods (default 3, the value previously hard-coded).
#   nbest number of runs (default 10), capped at the runs in res$groups.
#
# Per run, each county-year's mg is the slope of its (group, period) cell.
#   - An NA coefficient contributes 0.
#   - An NA group or period gives NA.
#   - A group/period outside 1..G / 1..np gives 0.
group_fips_10b <- function(G, data, res, np = 3, nbest = 10) {

  dtm <- data[, .SD, .SDcols = c("fips", "year", "period", "corn_area", "ddr")]
  dt_tot <- copy(dtm)

  # res$groups holds fips plus one column per run
  n_runs <- min(nbest, ncol(res$groups) - 1)

  for (k in seq_len(n_runs)) {

    # Group assignment of run k, joined to the county-year data
    dc <- res$groups[, .SD, .SDcols = c("fips", paste0("V", k))]
    setnames(dc, c("fips", "group"))
    dc <- merge(dc, dtm, by = "fips", all = TRUE)

    # Each row picks up only its own (group, period) slope
    dc[, beta := 0]
    for (p in seq_len(np)) {
      for (g in seq_len(G)) {
        b_gp <- res$beta[k, paste0("ddr_g", g, "_p", p)]
        if (is.na(b_gp)) {
          b_gp <- 0
        }
        dc[, beta := beta + b_gp * (group == g) * (period == p)]
      }
    }
    dc[, mg := beta]

    dt_tot <- merge(dt_tot, dc[, .(fips, year, mg)], by = c("fips", "year"), all = TRUE)
    setnames(dt_tot, "mg", paste0("mg_b", k))
  }

  dt_tot
}


